from typing import List, Dict, Any, Tuple
from typing import (
    Annotated,
    Sequence,
    TypedDict,
)
import sys

import json
import os
import asyncio
from typing import TypedDict, Optional, Literal
from utils import extract_json,cleanup_temp_images
from download_data import * 
from prompt_lib_qwen import *
from model_wrapper1 import Message, QwenChatModel
# Expose only GPUs 0, 1 and 2 to this process. This runs at import time,
# before the shared model below is constructed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
import numpy as np
# Module-level chat model shared by the planner, distortion, tool-selection
# and summarizer steps defined in this file.
model = QwenChatModel(model_name= "qwen2.5-full")
class Plan(TypedDict, total=False):
    """Switches that enable or skip individual pipeline stages.

    Every key is optional (``total=False``); ``distortion_step`` and
    ``tool_step`` treat a missing key as ``True``.
    """
    # Read by distortion_step: run the distortion-detection model call.
    distortion_detection: bool
    # Read by distortion_step: run the distortion-analysis model call.
    distortion_analysis: bool
    # Read by tool_step: ask the model which tool to use per distortion.
    tool_selection: bool
    # Read by tool_step: actually execute the chosen / required tools.
    tool_execute: bool

class IQAState(TypedDict):
    """State dictionary threaded through every step of the IQA pipeline.

    The annotations describe the values this module actually stores. A
    ``TypedDict`` is not validated at runtime, and several fields are filled
    straight from model-generated JSON, so treat the shapes as expectations
    rather than guarantees.
    """
    # Question text supplied by the caller.
    query: str
    # Lower-cased by the summarizer and compared with "iqa" / "other" /
    # "others"; None until something (presumably the planner) fills it in.
    task_type: Optional[str]
    # images[0] is used as the distorted image; images[1] is used as the
    # reference image when reference_type == "Full-Reference".
    images: List[str]
    # Region name -> detected distortions, written by distortion_step from the
    # detection reply (fallback: {"Global": []}).
    required_distortions: Dict[str, List[Any]]
    # Region name -> entries carrying "type", "severity" and "explanation".
    distortion_analysis: Optional[Dict[str, List[Dict[str, str]]]]
    # "Full-Reference" or "No-Reference".
    reference_type: str
    quality_reasoning: Optional[str]
    # Tool names to run directly; when non-empty, tool_execution_step ignores
    # selected_tools.
    required_tool: Optional[List[str]]
    # None, [] and [None] are all treated as "no specific objects" by
    # distortion_step.
    required_object_names: Optional[List[str]]
    choices_list: List[str]
    # Not written anywhere in this part of the file.
    distortion_source: Optional[str]
    # Region name -> {distortion -> tool name}; tool_step stores None when it
    # skips tool usage.
    selected_tools: Optional[Dict[str, Any]]
    # Region name -> {key -> score}; the selected-tool path stores
    # (tool_name, score) tuples, the required-tool path stores bare scores.
    quality_scores: Optional[Dict[str, Dict[str, Any]]]
    # A float on the numeric scoring path; otherwise whatever the summarizer
    # model returned (JSON field or raw text) or an "Unsupported: ..." string.
    final_answer: Any
    plan: Optional[Plan]
    # Not written anywhere in this part of the file.
    summary: Optional[str]
    # Error record: always carries "message"; some steps add "image_paths"
    # and main() adds "query".
    error: Optional[Dict[str, Any]]
    messages: Optional[List[Message]]

async def planner_step(state: IQAState) -> IQAState:
    """Ask the planner model for a plan and merge its JSON reply into ``state``.

    The user message is the question followed by the answer choices (one per
    line) when ``choices_list`` is a list. Every key of the parsed reply is
    copied into ``state`` via ``dict.update``.

    Args:
        state: Pipeline state; reads ``query``, ``choices_list`` and ``images``.

    Returns:
        The same ``state`` object. On failure ``state["error"]`` holds a dict
        with ``message`` and ``image_paths``; nothing is raised.
    """
    # Looked up once, outside the try, so the exception handler can report the
    # image paths without itself raising when the key is absent.
    images = state.get("images", [])
    try:
        question = state.get('query', "")
        choices_list = state.get('choices_list', "")

        if isinstance(choices_list, list):
            query = question + "\n" + "\n".join(choices_list)
        else:
            query = question

        system_prompt = build_planner_prompt()
        messages = [
            Message("system", [{"type": "text", "text": system_prompt}]),
            Message("user", [{"type": "text", "text": query}]),
        ]

        response = await model.ainvoke(messages)
        parsed_output = extract_json(response)
        print("[Planner Output]", parsed_output)

        if parsed_output and isinstance(parsed_output, dict):
            # Drop any error left by an earlier attempt so that a successful
            # retry is not still reported as a failure.
            state.pop("error", None)
            state.update(parsed_output)
        else:
            state["error"] = {
                "message": "Failed to parse JSON response from Planner.",
                "image_paths": images
            }
        return state

    except Exception as e:
        print(f"[Planner Error] {e}")
        state["error"] = {
            "message": f"Planner exception: {e}",
            "image_paths": images
        }
        return state

async def planner_step_with_retry(state: IQAState, max_retries=2):
    """Run ``planner_step`` until it produces a plan or the attempts run out.

    An attempt succeeds when the returned state has no error and a truthy
    ``plan``.

    Args:
        state: Pipeline state, mutated in place by ``planner_step``.
        max_retries: Maximum number of planner attempts.

    Returns:
        The state from the first successful attempt, or ``state`` with
        ``state["error"]`` set to a "failed after retries" record.
    """
    for attempt in range(max_retries):
        # The error recorded by a failed attempt lives in the shared state;
        # without clearing it a later, successful attempt would still look
        # failed and the retry could never succeed.
        state.pop("error", None)
        result = await planner_step(state)
        if (not result.get("error")) and result.get("plan"):
            return result
        print(f"[Planner Retry] Attempt {attempt + 1} failed.")
        # Back off only when another attempt will follow.
        if attempt < max_retries - 1:
            await asyncio.sleep(1)

    state["error"] = {
        "message": "Planner failed after retries.",
        "image_paths": state.get("images", [])
    }
    return state

# ======================== Distortion Step ========================
async def distortion_step(state: IQAState):
    """Detect distortions in the image(s) and, optionally, analyze them.

    Two model calls are made, each gated by the planner's ``plan``:

    1. Detection (``plan["distortion_detection"]``, on by default) stores the
       parsed reply in ``state["required_distortions"]``; an unparseable reply
       falls back to ``{"Global": []}``.
    2. Analysis (``plan["distortion_analysis"]``, on by default) stores the
       parsed reply in ``state["distortion_analysis"]``. When every detected
       list is empty the model is not called and a "None" placeholder entry is
       written per region instead.

    Args:
        state: Pipeline state. Reads ``query``, ``choices_list``, ``images``,
            ``reference_type``, ``required_object_names`` and ``plan``.

    Returns:
        The same ``state`` object. Exceptions raised while running detection
        or analysis are reported through ``state["error"]`` instead of being
        raised.
    """
    # ``state.get("plan", {})`` would hand back None when the key exists with a
    # None value (as in a freshly built state), so normalise explicitly.
    plan = state.get("plan")
    if not isinstance(plan, dict):
        plan = {}
    detection_required = plan.get("distortion_detection", True)
    analysis_required = plan.get("distortion_analysis", True)
    choices_list = state.get("choices_list", [])
    # Label the answer choices A, B, C, ...
    choice_dict = {chr(65 + i): v for i, v in enumerate(choices_list)} if choices_list else None
    user_question = (
        f"Question: {state['query']}\n\n"
        "Answer choices:\n" + "\n".join(f"{k}. {v}" for k, v in (choice_dict or {}).items())
    )
    try:
        dist_path = state["images"][0]
        ref_path = state["images"][1] if state.get("reference_type") == "Full-Reference" else None
        # None, [None] and [] all mean "no specific objects were requested".
        is_object_level = state.get("required_object_names") not in (None, [None], [])

        image_contents = [{"type": "image", "image": dist_path}]
        if ref_path:
            image_contents.append({"type": "image", "image": ref_path})
        img_expl = (
            "The first image is the distorted image. The second image is the reference image."
            if ref_path else
            "This image is the distorted image."
        )

        # ===== Step 1: Distortion Detection =====
        if detection_required:
            print("[Distortion Detection] Running...")
            sys_prompt = build_distortion_detection_prompt(is_object_level, user_question)
            # The prompt text is sent both as the system message and inside
            # the user message.
            user_query = (
                f"{img_expl}\n\n{sys_prompt}\n\n"
                + (
                    f"Please analyze specific regions {state.get('required_object_names')} of the distorted image."
                    if is_object_level else
                    "Please analyze the distorted image."
                )
            )
            messages = [
                Message("system", [{"type": "text", "text": sys_prompt}]),
                Message("user", image_contents + [{"type": "text", "text": user_query}])
            ]

            response = await model.ainvoke(messages)
            parsed = extract_json(response)
            print("[Detection Result]", parsed)

            if parsed:
                state["required_distortions"] = parsed
            else:
                print("[Distortion Detection] No valid distortions, fallback to empty.")
                state["required_distortions"] = {"Global": []}
        else:
            print("[Distortion Detection] Skipped")

        # ===== Step 2: Distortion Analysis =====
        if analysis_required:
            print("[Distortion Analysis] Running...")
            # Tolerate a key that is present but None / empty.
            distortion_dict = state.get("required_distortions") or {}
            all_empty = all(isinstance(v, list) and len(v) == 0 for v in distortion_dict.values())
            if all_empty:
                print("[Distortion Analysis] No distortion detected, skipping LLM analysis.")
                state["distortion_analysis"] = {
                    k: [{"type": "None", "severity": "None", "explanation": "No visible distortions detected."}]
                    for k in distortion_dict
                }
                return state

            sys_prompt2 = build_distortion_analysis_prompt_multi_object(
                distortion_dict,
                has_reference=bool(ref_path),
                user_question=user_question
            )
            user_query2 = (
                f"{img_expl}\n\n{sys_prompt2}\n\n"
                "Please analyze the following regions and their corresponding distortions."
            )
            messages = [
                Message("system", [{"type": "text", "text": sys_prompt2}]),
                Message("user", image_contents + [{"type": "text", "text": user_query2}])
            ]
            response2 = await model.ainvoke(messages)
            parsed2 = extract_json(response2)
            print("[Analysis Result]", parsed2)

            if parsed2:
                state["distortion_analysis"] = parsed2
            else:
                state["error"] = {"message": "Failed to parse analysis result"}
                return state
        else:
            print("[Distortion Analysis] Skipped")

    except Exception as e:
        state["error"] = {"message": f"Distortion step failed: {e}"}
    return state

# ======================== Tool Selection + Execution ========================
from subprocess_tool import execute_tool_by_name, tool_executor

# Semaphore with a single permit. Nothing in the code visible in this file
# acquires it.
# NOTE(review): presumably meant to serialize tool execution - confirm or remove.
TOOL_EXEC_SEMAPHORE = asyncio.Semaphore(1)
def build_tool_call_args_from_schema(tool_name: str, state: IQAState) -> Dict[str, str]:
    """Build the image arguments a tool requires from the images in ``state``.

    The tool's required parameter names are read from
    ``TOOL_DESCRIPTIONS[tool_name]["parameters"]["required"]``.
    ``reference_image`` is filled from ``state["images"][1]``; both
    ``distorted_image`` and ``image`` are filled from ``state["images"][0]``.
    Other required parameters are not filled here.

    Raises:
        KeyError: If ``tool_name`` is not in ``TOOL_DESCRIPTIONS``.
        IndexError: If a needed image is not present in ``state["images"]``.
    """
    required = TOOL_DESCRIPTIONS[tool_name]["parameters"]["required"]
    # (parameter name, index into state["images"]), in the order the
    # arguments are emitted.
    image_slots = (
        ("reference_image", 1),
        ("distorted_image", 0),
        ("image", 0),
    )
    return {
        param: state["images"][index]
        for param, index in image_slots
        if param in required
    }

def _resolve_tool_name(tool_name, fallback_tool):
    """Return ``fallback_tool`` when the model supplied no usable tool name."""
    # JSON null, numbers, nested structures: nothing that could name a tool.
    if not isinstance(tool_name, str):
        return fallback_tool
    if tool_name.strip().upper() in ("", "N/A", "NA", "NONE"):
        return fallback_tool
    return tool_name


async def tool_selection_step(state: IQAState) -> IQAState:
    """Ask the model which tool to run for each detected distortion.

    The reply is expected to map a region name to either a
    ``{distortion: tool_name}`` dict or a list of such single-pair dicts; lists
    are flattened into one dict. Tool names that are missing, not strings, or
    one of "N/A" / "NA" / "NONE" (case-insensitive) are replaced by a fallback
    tool: ``TopIQ_FR_tool`` for Full-Reference, ``QAlign_tool`` otherwise. An
    empty or non-dict reply selects the fallback tool for the ``Global``
    region.

    Args:
        state: Pipeline state; reads ``required_distortions`` and
            ``reference_type``.

    Returns:
        The same ``state`` with ``selected_tools`` set and ``error`` removed,
        or with ``error`` set when the step raised.
    """
    try:
        distortion_dict = state.get("required_distortions", {})
        ref_type = state.get("reference_type", "No-Reference")
        fallback_tool = "TopIQ_FR_tool" if ref_type == "Full-Reference" else "QAlign_tool"

        sys_prompt = build_tool_prompt(ref_type)
        user_query = f"Distortions: {distortion_dict}"

        messages = [
            Message("system", [{"type": "text", "text": sys_prompt}]),
            Message("user", [{"type": "text", "text": user_query}])
        ]
        response = await model.ainvoke(messages)
        parsed = extract_json(response)
        print("[Tool Selection]", parsed)

        # ===== Handle Missing or Invalid Output =====
        # Anything that is not a non-empty mapping (None, "", "n/a", a bare
        # list, ...) cannot be walked region by region below.
        if not parsed or not isinstance(parsed, dict):
            print("[Tool Selection] Fallback triggered due to empty result.")
            state["selected_tools"] = {"Global": {"default": fallback_tool}}
            state.pop("error", None)
            return state

        # ===== Replace 'N/A' entries with fallback tools =====
        normalized = {}
        for obj_name, dist_tool_map in parsed.items():
            if isinstance(dist_tool_map, list):
                # Convert list of dicts to flat dict; later pairs win on
                # duplicate distortion names.
                flat = {}
                for pair in dist_tool_map:
                    if not isinstance(pair, dict):
                        continue
                    for dist_type, tool_name in pair.items():
                        flat[dist_type] = _resolve_tool_name(tool_name, fallback_tool)
                normalized[obj_name] = flat
            elif isinstance(dist_tool_map, dict):
                normalized[obj_name] = {
                    dist_type: _resolve_tool_name(tool_name, fallback_tool)
                    for dist_type, tool_name in dist_tool_map.items()
                }
            else:
                # Unknown shape: pass through untouched.
                normalized[obj_name] = dist_tool_map

        state["selected_tools"] = normalized
        state.pop("error", None)
        return state

    except Exception as e:
        state["error"] = {"message": f"Tool selection failed: {e}"}
        return state

async def tool_execution_step(state: IQAState) -> IQAState:
    """Run the quality tools and store their scores in ``state["quality_scores"]``.

    Two mutually exclusive paths:

    * ``state["required_tool"]`` is non-empty: each listed tool is run and its
      score stored as ``quality_scores["Global"][tool_name]``.
    * otherwise: every ``(region, distortion, tool)`` entry of
      ``state["selected_tools"]`` is run and stored as
      ``quality_scores[region][distortion] = (tool_name, score)``.

    A tool is executed at most once per call; later requests for the same tool
    name reuse the first score. On the first failure ``state["error"]`` is set
    and the state is returned without ``quality_scores`` being written.
    """
    quality_scores = {}
    required_tool_list = state.get("required_tool")
    # Cache keyed on tool name only; the distortion argument is ignored when
    # deciding whether a cached score can be reused.
    execution_cache = {}

    async def cached_score(name, call_args):
        # Execute the tool on first use, then serve the remembered score.
        if name not in execution_cache:
            execution_cache[name] = await tool_executor.run(name, call_args)
        return execution_cache[name]

    # ===== Required Tool Path =====
    if required_tool_list:
        quality_scores["Global"] = {}
        for tool_name in required_tool_list:
            if tool_name not in TOOL_DESCRIPTIONS:
                state["error"] = {"message": f"[Tool Error] Tool '{tool_name}' not found."}
                return state
            try:
                call_args = build_tool_call_args_from_schema(tool_name, state)
                quality_scores["Global"][tool_name] = await cached_score(tool_name, call_args)
            except Exception as e:
                state["error"] = {"message": f"[Execution Error] {tool_name}: {e}"}
                return state

        state["quality_scores"] = quality_scores
        state.pop("error", None)
        return state

    # ===== Selected Tool Path =====
    selected = state.get("selected_tools", {})
    try:
        for obj_name, dist_map in selected.items():
            if isinstance(dist_map, list):
                # Flatten a list of {distortion: tool} dicts into one dict.
                dist_map = {
                    dist_type: tool
                    for pair in dist_map
                    for dist_type, tool in pair.items()
                }

            for dist_type, raw_name in dist_map.items():
                tool_name = raw_name.replace("functions.", "")
                if tool_name not in TOOL_DESCRIPTIONS:
                    state["error"] = {"message": f"[Tool Error] Tool '{tool_name}' not found."}
                    return state
                try:
                    call_args = build_tool_call_args_from_schema(tool_name, state)
                    call_args["distortion"] = dist_type
                    score = await cached_score(tool_name, call_args)
                    quality_scores.setdefault(obj_name, {})[dist_type] = (tool_name, score)
                except Exception as e:
                    state["error"] = {"message": f"[Execution Error] {tool_name} for {obj_name}/{dist_type}: {e}"}
                    return state

        state["quality_scores"] = quality_scores
        state.pop("error", None)
        return state

    except Exception as e:
        state["error"] = {"message": f"Tool execution failed: {e}"}
        return state

async def tool_step(state: IQAState):
    """Select and execute quality tools according to the planner's ``plan``.

    When ``state["required_distortions"]`` is exactly ``{"Global": []}`` no
    tool is consulted: ``selected_tools`` is set to ``None`` and a default
    score of 5.0 is recorded for the ``Global`` region. Otherwise tool
    selection and tool execution run in that order, each unless disabled by
    ``plan["tool_selection"]`` / ``plan["tool_execute"]`` (both on by default).

    Returns:
        The state; ``error`` is removed on success and left set when a
        sub-step failed.
    """
    # ``state.get("plan", {})`` would return None when the key exists with a
    # None value, so normalise before reading the switches.
    plan = state.get("plan")
    if not isinstance(plan, dict):
        plan = {}
    do_selection = plan.get("tool_selection", True)
    do_execute = plan.get("tool_execute", True)

    if state.get("required_distortions") == {"Global": []}:
        print("[Tool Step] Skipped due to no visible distortions.")
        state["selected_tools"] = None
        state["quality_scores"] = {"Global": {"default": 5.0}}
        state.pop("error", None)
        return state

    # Step 1: Tool Selection
    if do_selection:
        state = await tool_selection_step(state)
        if state.get("error"):
            return state
    else:
        print("[Tool Selection] Skipped")

    # Step 2: Tool Execution
    if do_execute:
        state = await tool_execution_step(state)
        if state.get("error"):
            return state
    else:
        print("[Tool Execution] Skipped")

    state.pop("error", None)
    return state
# ======================== Summarizer ========================
# Answer letters, in the order used when reading choice log-probabilities.
ABC_CHOICES = ['A', 'B', 'C', 'D', 'E']
# Not referenced by the code visible in this file.
ABC_WEIGHTS = [5, 4, 3, 2, 1]


def extract_choice_logprobs(top_logprobs) -> Dict[str, float]:
    """Return the best log-probability found for each answer letter A-E.

    Args:
        top_logprobs: Iterable of objects exposing ``.token`` (str) and
            ``.logprob`` (float).

    Returns:
        Mapping from each letter in ``ABC_CHOICES`` to the highest ``logprob``
        among entries whose stripped token starts with that letter, or
        ``-100.0`` when no entry matches.
    """
    missing_logprob = -100.0  # default logprob for a letter with no match
    best_by_letter = {}
    for letter in ABC_CHOICES:
        candidates = [
            entry.logprob
            for entry in top_logprobs
            if entry.token.strip().startswith(letter)
        ]
        best_by_letter[letter] = max(candidates) if candidates else missing_logprob
    return best_by_letter

def extract_probs_from_logprobs(logprobs_dict):
    """Turn per-letter log-probabilities into a normalized distribution.

    Args:
        logprobs_dict: Mapping from 'A'..'E' to a log-probability; a missing
            letter counts as ``-100.0``.

    Returns:
        A length-5 NumPy array summing to 1, ordered E, D, C, B, A, so that
        index 0 corresponds to level 1 and index 4 to level 5.
    """
    logprobs = np.array(
        [logprobs_dict.get(k, -100.0) for k in ['A', 'B', 'C', 'D', 'E']],
        dtype=float,
    )
    # Subtract the maximum before exponentiating: the normalized result is
    # mathematically unchanged, but very negative inputs no longer underflow
    # to an all-zero vector (which would produce 0/0 = NaN).
    shifted = np.exp(logprobs - np.max(logprobs))
    probs = shifted / np.sum(shifted)
    return probs[::-1]  # reverse A..E into E..A, i.e. levels 1..5

def compute_hvs_weights(avg_q, num_levels=5, b=1.0):
    """Return normalized Gaussian weights over the levels ``1..num_levels``.

    The weight of level ``i`` is proportional to ``exp(-b * (avg_q - i) ** 2)``.

    Args:
        avg_q: Centre of the Gaussian.
        num_levels: Number of levels.
        b: Sharpness; larger values concentrate the weight near ``avg_q``.

    Returns:
        NumPy array of length ``num_levels`` summing to 1 (empty when
        ``num_levels`` is 0).
    """
    levels = np.arange(1, num_levels + 1)
    exponents = -b * (avg_q - levels) ** 2
    if exponents.size == 0:
        return np.exp(exponents)
    # Shift by the largest exponent so that an ``avg_q`` far outside the level
    # range does not underflow every term to 0 and yield NaN; the normalized
    # weights are mathematically unchanged.
    w = np.exp(exponents - np.max(exponents))
    return w / np.sum(w)

def compute_hvs_final_score(logprobs_dict, avg_q, b=1.0):
    """Combine choice log-probabilities and an average tool score into one float.

    Args:
        logprobs_dict: Mapping from 'A'..'E' to log-probabilities.
        avg_q: Average tool score used as the centre of the level weights.
        b: Sharpness passed on to ``compute_hvs_weights``.

    Returns:
        ``sum(probs * center)`` as a Python float, where ``center`` is the
        weighted mean of the levels 1..5 under the weights from
        ``compute_hvs_weights``.
    """
    choice_probs = extract_probs_from_logprobs(logprobs_dict)
    level_weights = compute_hvs_weights(avg_q, num_levels=5, b=b)
    levels = np.arange(1, 6)
    center = np.sum(level_weights * levels)  # weighted center
    # NOTE(review): ``center`` is a scalar and ``choice_probs`` sums to 1, so
    # this equals ``center`` up to rounding and the log-probabilities do not
    # influence the result. Confirm whether an element-wise combination of
    # probabilities, weights and levels was intended.
    return float(np.sum(choice_probs * center))

def format_distortion_text(detected: Dict[str, List], analyzed: Optional[Dict[str, List]] = None) -> str:
    """Render detected / analyzed distortions as a plain-text bullet list.

    For every region in ``detected`` a ``"<region>:"`` header is emitted,
    followed by, in order of preference:

    * one ``"- <type> (<severity>): <explanation>"`` line per analysis entry,
      when ``analyzed`` holds a list for that region whose first item is a
      dict;
    * one ``"- <distortion>"`` line per detected distortion;
    * ``"No visible distortions."`` when nothing was detected.

    Both inputs come from model-generated JSON, so malformed data is
    tolerated: a missing ``detected`` yields an empty string, a non-dict
    ``analyzed`` is ignored, and analysis entries lacking a field are rendered
    with a placeholder instead of raising ``KeyError``.
    """
    analysis_by_region = analyzed if isinstance(analyzed, dict) else {}
    result_lines = []
    for obj, det_list in (detected or {}).items():
        result_lines.append(f"{obj}:")
        ana_list = analysis_by_region.get(obj)
        has_analysis = (
            isinstance(ana_list, (list, tuple))
            and len(ana_list) > 0
            and isinstance(ana_list[0], dict)
        )
        if has_analysis:
            for d in ana_list:
                if isinstance(d, dict):
                    result_lines.append(
                        f"- {d.get('type', 'Unknown')} ({d.get('severity', 'Unknown')}): "
                        f"{d.get('explanation', 'No explanation provided.')}"
                    )
                else:
                    # Stray non-dict entry in an otherwise structured list.
                    result_lines.append(f"- {d}")
        elif det_list:
            for d in det_list:
                result_lines.append(f"- {d}")
        else:
            result_lines.append("No visible distortions.")
    return "\n".join(result_lines)

def summarize_query_prompt(state: IQAState, choice_dict: Dict[str, str]) -> str:
    """Build the user prompt handed to the summarizer model.

    The prompt depends on whether specific objects were requested and whether
    answer choices exist:

    * objects + choices: question, choices, distortion analysis;
    * no objects + choices: question, choices, tool scores, distortion
      analysis;
    * no objects + no choices: question, a fixed Excellent..Bad choice list,
      tool scores, distortion analysis;
    * objects + no choices: the literal string
      ``"Unsupported: local object + open-ended query."``.

    Args:
        state: Pipeline state; reads ``query``, ``required_object_names``,
            ``quality_scores``, ``reference_type``, ``required_distortions``
            and ``distortion_analysis``.
        choice_dict: Mapping from choice letter to choice text (may be empty).
    """
    query_info = state.get("query", "")
    # ``[None]`` means "no specific objects", as in distortion_step; plain
    # truthiness would treat it as an object-level query and silently drop the
    # tool scores from the prompt.
    raw_objects = state.get("required_object_names")
    has_objects = bool(raw_objects) and raw_objects != [None]
    quality_scores = state.get("quality_scores", {})
    reference_type = state.get("reference_type", "No-Reference")

    if reference_type == "Full-Reference":
        ref_text = "Note: The first image is the distorted image, and the second is the reference image.\n\n"
    else:
        ref_text = ""

    # Distortion Text
    distortion_text = format_distortion_text(
        detected=state.get("required_distortions", {}),
        analyzed=state.get("distortion_analysis", {})
    )

    # Tool Response Text
    tool_lines = []
    if quality_scores:
        for obj_name, dist_map in quality_scores.items():
            tool_lines.append(f"{obj_name}:\n")
            for dist_type, value in dist_map.items():
                # Scores arrive either as (tool_name, score) or as a bare
                # number.
                if isinstance(value, tuple):
                    score = float(value[1])
                elif isinstance(value, (int, float)):
                    score = float(value)
                else:
                    score = None
                if score is not None:
                    tool_lines.append(f"- {dist_type}: {score:.2f}\n")
                else:
                    tool_lines.append(f"- {dist_type}: [Invalid score]\n")
    tool_text = "".join(tool_lines)

    question_block = ref_text + f"Question: {query_info}\n\n"
    choices_block = (
        "Answer choices:\n"
        + "\n".join(f"{k}. {v}" for k, v in (choice_dict or {}).items())
        + "\n\n"
    )

    if has_objects:
        # === Local + No Choice ===
        if not choice_dict:
            return "Unsupported: local object + open-ended query."
        # === Local + Choice ===
        return (
            question_block
            + choices_block
            + f"Distortion analysis:\n{distortion_text}\n\n"
        )

    # === Global + Choice ===
    if choice_dict:
        return (
            question_block
            + choices_block
            + f"Tool response:\n{tool_text} (where 1=Bad, 2=Poor, 3=Fair, 4=Good, 5=Excellent)\n\n"
            + f"Distortion analysis:\n{distortion_text}"
        )

    # === Global + No Choice ===
    return (
        question_block
        + "Answer choices:\nA. Excellent\nB. Good\nC. Fair\nD. Poor\nE. Bad\n\n"
        + f"Tool response (predicted perceptual quality under specific distortions, where 5 is best and 1 is worst):\n{tool_text}\n\n"
        + f"Distortion analysis:\n{distortion_text}"
    )

async def _summarize_once(state):
    """Run one summarization attempt and write the answer into ``state``.

    Raises whatever the prompt builders, the model or the score aggregation
    raise; retrying and error reporting are the caller's job.
    """
    query = state.get("query", "")
    # task_type may be present with value None, in which case the
    # ``dict.get`` default would not apply.
    qtype = str(state.get("task_type") or "others").lower()
    choices = state.get("choices_list", []) or []
    has_choices = bool(choices)
    # ``[None]`` means "no specific objects", as in distortion_step.
    raw_objects = state.get("required_object_names")
    objects = raw_objects if (raw_objects and raw_objects != [None]) else []
    qscores = state.get("quality_scores") or {}
    choice_dict = {chr(65 + i): v for i, v in enumerate(choices)}
    user_question = (
        f"Question: {query}\n\n"
        "Answer choices:\n" + "\n".join(f"{k}. {v}" for k, v in choice_dict.items())
    )

    dist = state["images"][0]
    ref = state["images"][1] if state.get("reference_type") == "Full-Reference" else None
    image_contents = [{"type": "image", "image": dist}]
    if ref:
        image_contents.append({"type": "image", "image": ref})

    if qtype in ("other", "others") and has_choices:
        system_prompt = build_summary_others_prompt(choices)
        user_text = user_question
    elif qtype == "iqa" and has_choices:
        system_prompt = build_summary_choice_prompt(choices, user_question)
        user_text = summarize_query_prompt(state, choice_dict)
    elif qtype == "iqa" and not has_choices and objects:
        state["final_answer"] = "Unsupported: local object + open-ended query."
        print("[Summarizer] Unsupported local object + no choices")
        return
    elif qtype == "iqa" and not has_choices and not objects:
        # Numeric path: read the next-token log-probabilities of the letters
        # A-E from the model and fold them into a single score.
        # Imported here because the module never binds the name ``torch``
        # itself and only this branch needs it.
        import torch
        import torch.nn.functional as F

        system_prompt = build_summary_nochoice_prompt()
        user_text = f"Question: {query}"

        messages_for_logits = [
            Message("system", [{"type": "text", "text": system_prompt}]),
            Message("user", image_contents + [{"type": "text", "text": user_text}])
        ]
        inputs_for_logits = model._build_inputs(messages_for_logits)
        with torch.no_grad():
            outputs = model.model(**inputs_for_logits)
            last_logits = outputs.logits[:, -1, :]
            log_probs = F.log_softmax(last_logits, dim=-1)[0]

        # extract A-E log-probs
        lp_dict = {
            ch: log_probs[model.processor.tokenizer.convert_tokens_to_ids(ch)].item()
            for ch in ABC_CHOICES
        }
        # aggregate previous tool scores; entries are either
        # (tool_name, score) tuples or bare numbers
        scores = [
            float(v[1]) if isinstance(v, tuple) else float(v)
            for dm in qscores.values() for v in dm.values()
        ]
        avg_q = float(np.mean(scores)) if scores else 0.0
        final_score = compute_hvs_final_score(lp_dict, avg_q)
        print(f"[Summarizer] {final_score!r}")
        state["final_answer"] = final_score
        return
    else:
        state["final_answer"] = "Unsupported: open-ended query."
        return

    # common choice-based path
    messages = [
        Message("system", [{"type": "text", "text": system_prompt}]),
        Message("user", image_contents + [{"type": "text", "text": user_text}])
    ]

    response = await model.ainvoke(messages)
    parsed = extract_json(response)

    if parsed and isinstance(parsed, dict) and "final_answer" in parsed:
        state["final_answer"] = parsed.get("final_answer")
        state["quality_reasoning"] = parsed.get("quality_reasoning", "")
    else:
        state["final_answer"] = response.strip()
        state["quality_reasoning"] = "Generated from plain text without JSON."

    print(f"[Summarizer Final] {state['final_answer']}")


async def summarize_quality_step(state: IQAState) -> IQAState:
    """Produce ``state["final_answer"]`` from the question, images and evidence.

    Depending on ``task_type``, the presence of answer choices and of requested
    objects, the answer is either the summarizer model's reply (choice-based
    questions), a numeric score derived from the model's A-E log-probabilities
    (open-ended IQA questions), or an ``"Unsupported: ..."`` string.

    Up to two attempts are made. After a failed attempt ``state["error"]``
    describes the exception; it is cleared again when a later attempt
    succeeds, and replaced by a "failed after retries" record (with the last
    exception text under ``cause``) when all attempts fail.

    Returns:
        The same ``state`` object; nothing is raised.
    """
    MAX_RETRIES = 2
    last_exc = None
    for attempt in range(MAX_RETRIES):
        try:
            await _summarize_once(state)
        except Exception as e:
            last_exc = e
            state["error"] = {"message": f"Summarizer failed: {e}"}
            # Back off only when another attempt will follow.
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(1)
        else:
            # Only remove an error this function recorded itself, so that a
            # successful retry is not reported as a failure downstream.
            if last_exc is not None:
                state.pop("error", None)
            return state
    state["error"] = {
        "message": "Summarizer failed after retries",
        "cause": str(last_exc),
    }
    return state


# Async function that runs the full IQA analysis for one sample
async def run_iqa_analysis(question, images,choices_list=None):
    """Run planner, distortion, tool and summarizer steps on a fresh state.

    The steps are plain awaited function calls executed in that order. The
    state is returned as soon as a step leaves a truthy ``state["error"]``;
    the summarizer's result is returned as is.

    Args:
        question: Question text stored as ``state["query"]``.
        images: Image paths; more than one image selects "Full-Reference".
        choices_list: Optional answer choices; ``None`` becomes ``[]``.

    Returns:
        The final state dictionary.
    """
    reference_type = "Full-Reference" if len(images) > 1 else "No-Reference"
    state = IQAState(
        query=question,
        task_type=None,
        images=images,
        choices_list=choices_list or [],
        distortions=None,
        reference_type=reference_type,
        required_tool=None,
        required_object_names=None,
        distortion_source=None,
        quality_scores=None,
        final_answer=None,
        plan=None,
        error=None,
        quality_reasoning=None,
        messages=[]
    )

    # step 1: planner (updates ``state`` in place)
    await planner_step_with_retry(state)
    if state.get("error"):
        return state

    # step 2: distortion detection & analysis, step 3: tool selection & execution
    for stage in (distortion_step, tool_step):
        state = await stage(state)
        if state.get("error"):
            return state

    # step 4: summarizer
    return await summarize_quality_step(state)

async def run_iqa_analysis_with_retry(question, images, choices_list=None, max_retries=2):
    """Run ``run_iqa_analysis`` until it finishes without an error.

    Each attempt starts from a fresh state.

    Args:
        question: Question text.
        images: Image paths for the sample.
        choices_list: Optional answer choices.
        max_retries: Maximum number of attempts; values below 1 still result
            in a single attempt so that a state is always returned.

    Returns:
        The first error-free state, or the state of the last attempt.
    """
    state = None
    for attempt in range(max(1, max_retries)):
        state = await run_iqa_analysis(question, images, choices_list)
        error = state.get("error")
        if not error:
            return state
        # Errors written by this module are dicts with a "message"; fall back
        # to the raw value for anything else that ended up under the key.
        detail = error.get("message", error) if isinstance(error, dict) else error
        print(f"[Retry {attempt+1}] Error encountered: {detail}")
    print("[Final Failure] All retries failed.")
    return state

# Run IQA Analysis
def _append_json_record(path, record, is_first):
    """Append ``record`` to the JSON array being written incrementally to ``path``."""
    with open(path, "a", encoding="utf-8") as fh:
        # The separator goes in front of every record but the first, so the
        # array stays valid no matter which sample happens to be the last one
        # written to this particular file.
        if not is_first:
            fh.write(",\n")
        json.dump(record, fh, ensure_ascii=False, indent=2)


async def main():
    """Run the pipeline over the QBench samples and write two JSON arrays.

    Successful samples are appended to the results file and failed ones to the
    errors file. Both files are written incrementally, one record per sample,
    so records already written remain on disk if the run is interrupted.
    """
    output_path = "/root/IQA/IQA-Agent/AgenticIQA/results/qbench_ours_test.json"
    errors_path = "/root/IQA/IQA-Agent/AgenticIQA/results/qbench_ours_test_error.json"

    for path in (output_path, errors_path):
        with open(path, "w", encoding="utf-8") as fh:
            fh.write("[\n")

    img_paths, questions, choices_list, correct_choices,types, concerns = QBench_data()
    # img_paths, questions, choices_list, correct_choices = QBench_error_recovery_data()
    records_written = {output_path: 0, errors_path: 0}
    samples = zip(img_paths, questions, choices_list, correct_choices, types, concerns)
    for i, (img_path, question, choices, correct, sample_type, concern) in enumerate(samples):
        print(f"[{i}] Running analysis...")
        output_state = await run_iqa_analysis_with_retry(question, img_path, choices)
        # The message history is not part of the saved record.
        output_state.pop("messages", None)

        error = output_state.get("error")
        if error:
            if not isinstance(error, dict):
                error = {"message": str(error)}
            error["image_paths"] = img_path
            error["query"] = question
            target, record = errors_path, error
        else:
            target = output_path
            record = {
                "correct_choice": correct,
                "type": sample_type,
                "concern": concern,
                "state": output_state
            }
        _append_json_record(target, record, is_first=records_written[target] == 0)
        records_written[target] += 1
        cleanup_temp_images()

    for path in (output_path, errors_path):
        with open(path, "a", encoding="utf-8") as fh:
            # Finish the last record's line (if any) and close the array.
            fh.write("\n]\n" if records_written[path] else "]\n")

# Script entry point: run the benchmark loop when this file is executed
# directly.
if __name__ == "__main__":
    asyncio.run(main())


